import torch
import torch.nn as nn

from modeling_t5 import VLT5


class VLT5MMT(VLT5):
    """VL-T5 wrapper with train/test step helpers and optional augmentation.

    Two optional training-time paths are available in ``train_step``:
    ``RD`` duplicates every example along the batch dimension, and ``AG``
    appends a copy of the text inputs in which random tokens are replaced by
    the UNK id. Both are off by default. How ``forward`` uses the ``RD`` /
    ``AG`` flags is defined in ``VLT5`` and is not visible here.
    """

    def __init__(self, config):
        super().__init__(config)

    def _replace_token(self, inputs, masking_indices, mask_index, vocab_size):
        """Overwrite ``inputs`` in place with ``mask_index`` where ``masking_indices`` is True.

        ``vocab_size`` is accepted for interface compatibility but is unused.
        Returns the same (mutated) ``inputs`` tensor.
        """
        inputs[masking_indices] = mask_index
        return inputs

    def _mask_tokens_by_word(self, inputs, augmentation_masking_probability, *,
                             vocab_size=32100, eos_index=1, pad_index=0, unk_index=2):
        """Return a copy of ``inputs`` with randomly chosen tokens replaced by ``unk_index``.

        Each position is selected independently with probability
        ``augmentation_masking_probability``. Positions holding ``eos_index``,
        ``pad_index`` or ``unk_index`` are never replaced. ``inputs`` itself is
        not modified.

        The keyword defaults are the values that used to be hard-coded here.
        NOTE(review): they appear to be the T5 tokenizer's ids. Confirm before
        using this with a different vocabulary.
        """
        available_token_indices = (
            (inputs != eos_index) & (inputs != pad_index) & (inputs != unk_index)
        )
        random_masking_indices = torch.bernoulli(
            torch.full(inputs.shape, augmentation_masking_probability, device=inputs.device)
        ).bool()
        masking_indices = random_masking_indices & available_token_indices
        # Clone so the in-place replacement never touches the caller's tensor.
        return self._replace_token(inputs.clone(), masking_indices, unk_index, vocab_size)

    def train_step(self, batch, *, RD=False, AG=False, augmentation_masking_probability=0.1):
        """Run one training forward pass and return ``{'loss': loss}``.

        Args:
            batch: mapping with tensor entries 'vis_feats', 'input_ids',
                'boxes', 'vis_attention_mask', 'target_ids', and an 'epoch'
                value accepted by ``int()``.
            RD: if True, concatenate every input and label tensor with itself
                along dim 0. This used to be a hard-coded ``False`` local.
            AG: if True, the text inputs become ``[original; UNK-masked copy]``.
                Explicit ``decoder_input_ids`` are built the same way from the
                right-shifted targets. This doubles only the text tensors.
                NOTE(review): without ``RD`` the vision inputs and labels keep
                the original batch size. Confirm ``forward`` handles that case,
                or always pair ``AG`` with ``RD``.
            augmentation_masking_probability: per-token replacement
                probability used by the ``AG`` path.
        """
        device = next(self.parameters()).device
        vis_feats = batch['vis_feats'].to(device)
        input_ids = batch['input_ids'].to(device)
        vis_pos = batch['boxes'].to(device)
        vis_attention_mask = batch['vis_attention_mask'].to(device)
        lm_labels = batch["target_ids"].to(device)
        decoder_input_ids = None

        if RD:
            # The tensors are already on the device. torch.cat allocates a new
            # tensor, so neither a clone nor a second host->device copy is needed.
            vis_feats = torch.cat([vis_feats, vis_feats], 0)
            input_ids = torch.cat([input_ids, input_ids], 0)
            vis_pos = torch.cat([vis_pos, vis_pos], 0)
            vis_attention_mask = torch.cat([vis_attention_mask, vis_attention_mask], 0)
            lm_labels = torch.cat([lm_labels, lm_labels], 0)

        if AG:
            # Masks are drawn on the un-moved batch tensors (as before), so the
            # same random-number stream is used. _mask_tokens_by_word clones
            # internally, so no extra clone is required here.
            masked_input_ids = self._mask_tokens_by_word(
                batch['input_ids'], augmentation_masking_probability)
            input_ids = torch.cat([batch['input_ids'], masked_input_ids], 0).to(device)
            shifted_labels = self._shift_right(batch['target_ids'])
            masked_decoder_inputs = self._mask_tokens_by_word(
                shifted_labels, augmentation_masking_probability)
            decoder_input_ids = torch.cat([shifted_labels, masked_decoder_inputs], 0).to(device)

        output = self(
            input_ids=input_ids,
            vis_inputs=(vis_feats, vis_pos),
            vis_attention_mask=vis_attention_mask,
            labels=lm_labels,
            decoder_input_ids=decoder_input_ids,
            reduce_loss=True,
            return_dict=True,
            RD=RD,
            AG=AG,
            epoch=int(batch['epoch'])
        )

        return {'loss': output['loss']}

    def test_step(self, batch, **kwargs):
        """Generate and decode predictions for ``batch``.

        Extra keyword arguments are forwarded to ``self.generate``. Returns
        ``{'pred': <decoded sequences>}``. ``self.tokenizer`` is not assigned
        in this class and must be provided elsewhere.
        """
        device = next(self.parameters()).device
        vis_feats = batch['vis_feats'].to(device)
        input_ids = batch['input_ids'].to(device)
        vis_pos = batch['boxes'].to(device)
        vis_attention_mask = batch['vis_attention_mask'].to(device)

        output = self.generate(
            input_ids=input_ids,
            vis_inputs=(vis_feats, vis_pos),
            vis_attention_mask=vis_attention_mask,
            **kwargs
        )

        generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return {'pred': generated_sents}


from modeling_bart import VLBart


class VLBartMMT(VLBart):
    """VL-BART model with train/test step helpers.

    Both helpers first move the batch tensors they need onto the device that
    holds the model's parameters, then call the underlying model.
    """

    def __init__(self, config):
        super().__init__(config)

    def train_step(self, batch):
        """Compute the training loss for ``batch``.

        Reads 'vis_feats', 'input_ids', 'boxes', 'vis_attention_mask' and
        'target_ids' from ``batch`` and returns ``{'loss': <model loss>}``.
        """
        device = next(self.parameters()).device
        # Keys are moved in a fixed order: features, text, boxes, mask, targets.
        feats, token_ids, boxes, vis_mask, labels = (
            batch[name].to(device)
            for name in ('vis_feats', 'input_ids', 'boxes', 'vis_attention_mask', 'target_ids')
        )
        model_out = self(
            input_ids=token_ids,
            vis_inputs=(feats, boxes),
            vis_attention_mask=vis_mask,
            labels=labels,
            reduce_loss=True,
            return_dict=True,
        )
        return {'loss': model_out['loss']}

    def test_step(self, batch, **kwargs):
        """Generate and decode predictions for ``batch``.

        Extra keyword arguments are forwarded to ``self.generate``. Returns
        ``{'pred': <decoded sequences>}``. ``self.tokenizer`` is not assigned
        in this class and must be provided elsewhere.
        """
        device = next(self.parameters()).device
        feats, token_ids, boxes, vis_mask = (
            batch[name].to(device)
            for name in ('vis_feats', 'input_ids', 'boxes', 'vis_attention_mask')
        )
        generated_ids = self.generate(
            input_ids=token_ids,
            vis_inputs=(feats, boxes),
            vis_attention_mask=vis_mask,
            **kwargs
        )
        predictions = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return {'pred': predictions}
